#!/usr/bin/env python3
"""
run_correlation.py

Compute the correlation between per‑link flip counts and local Wilson‑loop
fluctuations for a given gauge‑field ensemble on an L×L lattice.  This script
is adapted from the upstream implementation and is intended to run
stand‑alone within this repository.  It reads a YAML configuration file to
locate the flip counts and gauge configurations, enumerates square size×size
Wilson loops (size 1 is the elementary plaquette) for the specified loop sizes,
computes |W| for each gauge configuration and loop, takes the per‑link variance
of |W| over all loops passing through that link, and correlates these variances
with the flip‑count vector.  A bootstrap procedure is used to estimate confidence
intervals on the Pearson correlation coefficient.
"""

import os
import yaml
import numpy as np
import pandas as pd
from scipy.stats import pearsonr


def load_config(path: str) -> dict:
    """
    Load a YAML configuration file and return its top-level mapping.

    The file is opened in binary mode so PyYAML determines the encoding
    itself (UTF-8, or UTF-16 when a BOM is present) rather than relying on
    the platform's locale encoding.  ``yaml.safe_load`` is used, so only
    standard YAML tags are accepted.

    Raises:
        FileNotFoundError: if ``path`` does not exist.
        ValueError: if the document is empty or its top level is not a
            mapping (``safe_load`` would otherwise hand back ``None`` or a
            scalar, which breaks every ``cfg[...]`` lookup downstream).
    """
    with open(path, "rb") as f:
        data = yaml.safe_load(f)
    if not isinstance(data, dict):
        raise ValueError(
            f"Config {path} must contain a YAML mapping at the top level, "
            f"got {type(data).__name__}"
        )
    return data


def lattice_links(L: int) -> np.ndarray:
    """
    Enumerate the links of a periodic L×L lattice as ((x, y), mu) pairs.

    Links are ordered x-major, then y, then mu in (0, 1), so link number k
    satisfies k == (x * L + y) * 2 + mu.  The pairs are packed into a NumPy
    array with ``dtype=object``.
    """
    pairs = [((x, y), mu) for x in range(L) for y in range(L) for mu in (0, 1)]
    return np.array(pairs, dtype=object)


def precompute_loops_by_size(L: int, loop_sizes: list, links: np.ndarray):
    """
    Precompute all square size×size Wilson loops on a periodic L×L lattice
    and index which loops involve each link.

    For every lower-left corner (x0, y0) and every ``size`` in ``loop_sizes``
    one closed loop is built, walking east along the bottom edge, north up the
    right edge, west along the top edge and south down the left edge.  Each
    loop is a list of ``(link_index, orientation)`` pairs: +1 means the link
    ``(x, y, mu)`` is traversed from (x, y) towards +mu, -1 means it is
    traversed backwards (from its far end back to (x, y)).

    Args:
        L: lattice extent; coordinates wrap modulo L.
        loop_sizes: side lengths of the square loops to enumerate.
        links: sequence of ``((x, y), mu)`` pairs; the position of a pair in
            this sequence is its link index.

    Returns:
        ``{size: {"loops": [...], "loops_by_link": {link_idx: [loop_idx, ...]}}}``.
        Each loop appears at most once in a link's list, even when it passes
        through that link more than once (possible when ``size >= L``).
    """
    link_index = {(pos[0], pos[1], mu): idx for idx, (pos, mu) in enumerate(links)}

    def link_at(x: int, y: int, mu: int) -> int:
        # Periodic boundary conditions: wrap both coordinates onto the torus.
        return link_index[(x % L, y % L, mu)]

    loops_by_size = {}
    for size in loop_sizes:
        loops = []
        loops_by_link = {i: [] for i in range(len(links))}
        for x0 in range(L):
            for y0 in range(L):
                loop = (
                    # east: (x0+s, y0) -> (x0+s+1, y0)
                    [(link_at(x0 + s, y0, 0), +1) for s in range(size)]
                    # north: (x0+size, y0+s) -> (x0+size, y0+s+1)
                    + [(link_at(x0 + size, y0 + s, 1), +1) for s in range(size)]
                    # west: step from x0+size-s back to x0+size-s-1 uses the link
                    # that *starts* at x0+size-1-s, traversed backwards
                    + [(link_at(x0 + size - 1 - s, y0 + size, 0), -1) for s in range(size)]
                    # south: likewise the link starting at y0+size-1-s, backwards
                    + [(link_at(x0, y0 + size - 1 - s, 1), -1) for s in range(size)]
                )
                loop_idx = len(loops)
                loops.append(loop)
                # dict.fromkeys de-duplicates while keeping first-seen order.
                for link_idx in dict.fromkeys(idx for idx, _ in loop):
                    loops_by_link[link_idx].append(loop_idx)
        loops_by_size[size] = {"loops": loops, "loops_by_link": loops_by_link}
    return loops_by_size


def load_gauge_configuration(path: str, group: str, L: int) -> np.ndarray:
    """
    Read one gauge configuration from a ``.npy`` file and return it as complex.

    The stored array must have shape (L, L, 2, d, d), where the matrix size d
    is 1 for "U1", 2 for "SU2" and 3 for "SU3".  Pickled objects are refused.

    Raises:
        FileNotFoundError: if ``path`` does not exist.
        KeyError: if ``group`` is not one of the three names above.
        ValueError: if the stored array has any other shape.
    """
    if not os.path.exists(path):
        raise FileNotFoundError(f"Missing gauge configuration: {path}")
    arr = np.load(path, allow_pickle=False)
    matrix_dim = {"U1": 1, "SU2": 2, "SU3": 3}
    d = matrix_dim[group]
    expected_shape = (L, L, 2, d, d)
    if arr.shape == expected_shape:
        return arr.astype(complex)
    raise ValueError(
        f"Bad config shape for {path}: got {arr.shape}, expected {expected_shape}"
    )


def compute_loop_value(cfg: np.ndarray, loop: list):
    """
    Return the trace of the ordered product of link matrices along ``loop``.

    ``cfg`` must be five-dimensional, (L, L, 2, d, d), with ``cfg[x, y, mu]``
    the d×d link matrix.  Every entry of ``loop`` is
    ``(link_index, orientation)``.  The flat index is decoded as
    ``link_index == (x * L + y) * 2 + mu``.  Orientation +1 multiplies by the
    stored matrix; any other value multiplies by its conjugate transpose.
    Matrices are multiplied left-to-right in loop order, and an empty loop
    gives the trace of the identity, i.e. d.
    """
    n, _, _, dim, _ = cfg.shape
    product = np.eye(dim, dtype=complex)
    for link_idx, orientation in loop:
        x, rest = divmod(link_idx, 2 * n)
        y, mu = divmod(rest, 2)
        U = cfg[x, y, mu]
        if orientation != +1:
            U = U.conj().T
        product = product @ U
    return np.trace(product)


def bootstrap_ci(
    x: np.ndarray,
    y: np.ndarray,
    resamples: int = 1000,
    alpha: float = 0.05,
    rng=None,
):
    """
    Compute a percentile-bootstrap confidence interval for Pearson's r.

    Paired observations ``(x[k], y[k])`` are resampled with replacement
    ``resamples`` times.  Pearson's r is computed on each resample, and the
    ``alpha/2`` and ``1 - alpha/2`` percentiles of those values are returned
    as the lower and upper bounds of the (1 - alpha) interval.

    Resamples in which ``x`` or ``y`` is constant have no defined
    correlation and are discarded.  A single such draw would otherwise make
    ``pearsonr`` return NaN, and ``np.percentile`` would propagate that NaN
    into both bounds.  If every resample is degenerate, ``(nan, nan)`` is
    returned.

    Args:
        x, y: equal-length 1-D samples.
        resamples: number of bootstrap resamples (must be >= 1).
        alpha: two-sided significance level of the interval.
        rng: ``None`` (default) draws from NumPy's global random state, so
            ``np.random.seed`` still controls reproducibility.  Otherwise an
            int seed or a ``numpy.random.Generator``, passed through
            ``np.random.default_rng``.

    Raises:
        ValueError: if ``x`` and ``y`` differ in length, hold fewer than two
            observations, or ``resamples`` is less than 1.
    """
    x = np.asarray(x)
    y = np.asarray(y)
    N = len(x)
    if len(y) != N:
        raise ValueError(f"x and y must have the same length, got {N} and {len(y)}")
    if N < 2:
        raise ValueError("bootstrap_ci needs at least two paired observations")
    if resamples < 1:
        raise ValueError(f"resamples must be >= 1, got {resamples}")

    draw = np.random.randint if rng is None else np.random.default_rng(rng).integers
    rs = []
    for _ in range(resamples):
        i = draw(0, N, N)
        xi, yi = x[i], y[i]
        if np.ptp(xi) == 0 or np.ptp(yi) == 0:
            continue  # constant resample: correlation undefined
        r, _ = pearsonr(xi, yi)
        rs.append(r)

    if not rs:
        return float("nan"), float("nan")
    lower = np.percentile(rs, 100 * alpha / 2)
    upper = np.percentile(rs, 100 * (1 - alpha / 2))
    return lower, upper


def _infer_lattice_size(n_links: int) -> int:
    """Return the lattice extent L with ``n_links == 2 * L * L``; raise ValueError otherwise."""
    # round() guards against a sqrt landing just below an integer; the exact
    # integer check below is what actually decides validity.
    L = int(round(np.sqrt(n_links / 2)))
    if L < 1 or 2 * L * L != n_links:
        raise ValueError(
            f"flip_counts length {n_links} is not 2*L^2 for integer L"
        )
    return L


def _link_variances(config_paths, group, L, loops, loops_by_link):
    """Per link, return the variance of |W| over every (configuration, loop through that link) pair."""
    vals_per_link = [[] for _ in loops_by_link]
    for fpath in config_paths:
        cfg_mat = load_gauge_configuration(fpath, group, L)
        abs_w = [abs(compute_loop_value(cfg_mat, loop)) for loop in loops]
        for i, vals in enumerate(vals_per_link):
            vals.extend(abs_w[li] for li in loops_by_link[i])
    return np.array([np.var(v) for v in vals_per_link], dtype=float)


def main() -> None:
    """
    Run the flip-count / Wilson-loop-variance correlation analysis.

    Reads ``config.yaml`` from the repository root (the parent of this
    script's directory).  For every (gauge group, loop size) pair it computes:
    - the per-link variance of |W|;
    - its Pearson correlation with the flip counts;
    - a bootstrap confidence interval for that correlation.
    It prints a progress line per pair and writes all results to
    ``cfg["results"]["output_csv"]`` (relative to the repository root).
    Pairs whose inputs are numerically constant are skipped with a warning.
    """
    repo_root = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
    cfg = load_config(os.path.join(repo_root, "config.yaml"))

    fc_path = os.path.join(repo_root, cfg["data"]["flip_counts_path"])
    # ravel() so pearsonr always receives a 1-D vector.  NOTE(review): if the
    # file stores an (L, L, 2) array, C-order flattening matches the
    # (x, y, mu) link ordering of lattice_links — confirm with the producer.
    flip_counts = np.load(fc_path).astype(float).ravel()
    L = _infer_lattice_size(flip_counts.size)

    links = lattice_links(L)
    loop_sizes = cfg["parameters"]["loop_sizes"]
    loops_by_size = precompute_loops_by_size(L, loop_sizes, links)
    gauge_groups = cfg["parameters"]["gauge_groups"]
    ensemble = cfg["parameters"]["ensemble_size"]
    configs_dir = os.path.join(repo_root, cfg["data"]["gauge_configs_dir"])
    n_resamples = cfg["analysis"]["bootstrap_resamples"]
    alpha = cfg["analysis"]["significance_level"]
    flips_constant = np.allclose(np.std(flip_counts), 0)

    results = []
    for group in gauge_groups:
        config_paths = [
            os.path.join(configs_dir, f"{group}_cfg_{t:03d}.npy")
            for t in range(ensemble)
        ]
        for size in loop_sizes:
            var_i = _link_variances(
                config_paths,
                group,
                L,
                loops_by_size[size]["loops"],
                loops_by_size[size]["loops_by_link"],
            )

            # guard against constant inputs.  NOTE(review): for U1 with
            # unit-modulus link phases |W| is identically 1, so var_i is
            # constant and U1 is always skipped — confirm |W| (vs Re W) is intended.
            if np.allclose(np.std(var_i), 0) or flips_constant:
                print(
                    f"⚠️  Skipping {group}, loop_size={size} (constant input → undefined correlation)"
                )
                continue

            r, p = pearsonr(flip_counts, var_i)
            lower, upper = bootstrap_ci(
                flip_counts, var_i, resamples=n_resamples, alpha=alpha
            )
            results.append(
                {
                    "gauge_group": group,
                    "loop_size": size,
                    "r": r,
                    "p_value": p,
                    "r_ci_lower": lower,
                    "r_ci_upper": upper,
                }
            )
            print(
                f"Finished {group}, loop_size={size}: r={r:.4f}, p={p:.3e}, CI=({lower:.4f},{upper:.4f})"
            )

    # save to CSV.  Columns are given explicitly: if every combination was
    # skipped, results is empty and a column-less DataFrame would make
    # dropna(subset=["r"]) raise KeyError.
    columns = ["gauge_group", "loop_size", "r", "p_value", "r_ci_lower", "r_ci_upper"]
    df = pd.DataFrame(results, columns=columns).dropna(subset=["r"])
    out_csv = os.path.join(repo_root, cfg["results"]["output_csv"])
    os.makedirs(os.path.dirname(out_csv), exist_ok=True)
    df.to_csv(out_csv, index=False)


if __name__ == "__main__":
    # Run the analysis only when executed as a script, not when imported.
    main()